#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
#include <string.h>
#include <ctype.h>

// == values ==

// Tags for both reader data (NIL..NUMBER) and expanded AST nodes (CONSTANT..SETBANG).
enum tag {
    NIL, PAIR, SYMBOL, NUMBER,
    CONSTANT, LAMBDA, IFELSE, BEGIN, SETBANG,
};

typedef struct val* Val;

// A heap cell. Children live in the trailing `field` array, sized by allocVal():
//   PAIR     2: car, cdr
//   CONSTANT 1: quoted datum
//   LAMBDA   2: parameter list, body
//   IFELSE   3: test, consequent, alternative
//   BEGIN    2: first expression, rest
//   SETBANG  2: target symbol, value expression
struct val {
    enum tag tag;
    union {
        const char *name;   // SYMBOL: symbol text
        long int value;     // NUMBER: integer value
    };
    Val field[];            // C99 flexible array member (was GNU zero-length array)
};

// The unique empty list. It has no trailing fields, so it must never be car/cdr'ed.
struct val nil = {.tag = NIL};

// Pair accessors. The expansions are parenthesized so they compose safely inside
// larger expressions, and they remain lvalues (car(p) = x is valid).
#define car(v_) ((v_)->field[0])
#define cdr(v_) ((v_)->field[1])
#define cadr(v_) car(cdr(v_))
#define cddr(v_) cdr(cdr(v_))
#define caddr(v_) car(cddr(v_))
#define cadddr(v_) cadr(cddr(v_))
// Proper-list constructors terminated by the shared `nil`.
#define list1(v_) cons(v_, &nil)
#define list2(v1_,v2_) cons(v1_, list1(v2_))
#define list3(v1_,v2_,v3_) cons(v1_, list2(v2_,v3_))

Val allocVal(enum tag tag, size_t n) {
    Val v = malloc(sizeof(struct val) + n * sizeof(Val));
    v->tag = tag;
    return v;
}

// Builds a fresh PAIR whose car is `hd` and whose cdr is `tl`.
Val cons(Val hd, Val tl) {
    Val cell = allocVal(PAIR, 2);
    cell->field[0] = hd;
    cell->field[1] = tl;
    return cell;
}

// Wraps `name` in a SYMBOL cell. The string is stored by pointer, not copied.
Val makeSymbol(const char* name) {
    Val sym = allocVal(SYMBOL, 0);
    sym->name = name;
    return sym;
}

// Wraps an integer literal in a NUMBER cell.
Val makeNumber(long int value) {
    Val num = allocVal(NUMBER, 0);
    num->value = value;
    return num;
}

// Nonzero iff `v` is a SYMBOL whose text equals `name`. Non-symbols never match.
int symbolEq(Val v, const char *name) {
    if (v->tag != SYMBOL)
        return 0;
    return !strcmp(v->name, name);
}

// Returns a freshly consed copy of `list` in reverse order. Walking stops at the
// first non-PAIR, so an improper tail is dropped.
Val reverse(Val list) {
    Val acc = &nil;
    while (list->tag == PAIR) {
        acc = cons(car(list), acc);
        list = cdr(list);
    }
    return acc;
}

// == reader ==

// Read cursor into the NUL-terminated program text; the reader advances it.
const char *source;

// Forward declaration with a real prototype (an empty `()` declares no parameter types).
Val readSexp(void);

// Classifies a source character for the reader:
//   0 = whitespace, 1 = delimiter or end of input, 2 = symbol/number constituent.
int describeChar(char c) {
    // Checked first: strchr would also "find" the NUL terminator of its set.
    if (c == '\0')
        return 1;
    if (strchr(" \t\r\n", c) != NULL)
        return 0;
    if (strchr("()[]{}.';", c) != NULL)
        return 1;
    return 2;
}

// Advances `source` past whitespace and `;` line comments. It stops on the first
// delimiter or constituent character, or on the NUL terminator.
void skipWhitespace(void) {
    for (;;) {
        if (*source == ';') {
            // Stop at NUL as well as newline. A comment on the last line with no
            // trailing newline must not run past the end of the buffer.
            while (*source != '\n' && *source != '\0')
                source++;
        }
        if (describeChar(*source) != 0)
            break;
        source++;
    }
}

// Consumes a maximal run of constituent characters and returns it as a new
// heap-allocated NUL-terminated string. Returns NULL (consuming nothing) if the
// cursor is not on a constituent. Exits on allocation failure.
char *readSymbol(void) {
    const char *start = source;
    while (describeChar(*source) == 2)
        source++;
    size_t len = (size_t)(source - start);
    if (len == 0)
        return NULL;
    char *text = malloc(len + 1);
    if (text == NULL) {
        fprintf(stderr, "error: out of memory\n");
        exit(1);
    }
    memcpy(text, start, len);
    text[len] = '\0';
    return text;
}

// Reads list elements up to the closing delimiter `eol` and consumes it
// ('\0' means "until end of input", which is never consumed). Supports a
// dotted tail: (a b . c).
Val readSexpList(char eol) {
    skipWhitespace();
    if (*source == eol) {
        if (eol != '\0') source++;
        return &nil;
    }
    if (*source == '.') {
        source++;
        Val tl = readSexp();
        // The dotted tail must be followed directly by the closing delimiter.
        // Consume it here, as in the empty-list case above. Otherwise the caller
        // sees a stray ')' later.
        skipWhitespace();
        if (*source != eol) {
            if (eol != '\0')
                fprintf(stderr, "error: expected '%c' after dotted tail\n", eol);
            else
                fprintf(stderr, "error: expected EOF after dotted tail\n");
            exit(1);
        }
        if (eol != '\0') source++;
        return tl;
    }
    Val hd = readSexp();
    Val tl = readSexpList(eol);
    return cons(hd, tl);
}

// Reads one datum: a bracketed list ((), [] or {}), a quoted datum, a number or
// a symbol. Exits with "error: unexpected ..." if none starts at the cursor.
Val readSexp(void) {
    skipWhitespace();
    switch (*source) {
    case '(':
        source++;
        return readSexpList(')');
    case '[':
        source++;
        return readSexpList(']');
    case '{':
        source++;
        return readSexpList('}');
    case '\'':
        {
            // 'x  ->  (quote x)
            source++;
            Val data = readSexp();
            static struct val quote = {.tag = SYMBOL, .name = "quote"};
            return list2(&quote, data);
        }
    default:
        break;
    }
    char *symbol = readSymbol();
    if (symbol == NULL) {
        char who[4] = "EOF";
        if (*source) {
            who[1] = *source;
            who[0] = who[2] = '\'';
        }
        fprintf(stderr, "error: unexpected %s\n", who);
        exit(1);
    }
    // A token is a number only if strtol consumes all of it, so "1+" and "2x"
    // stay symbols instead of silently becoming 1 and 2.
    char *end;
    long int value = strtol(symbol, &end, 10);
    if (end != symbol && *end == '\0') {
        free(symbol);  // the NUMBER cell keeps no reference to the token text
        return makeNumber(value);
    }
    return makeSymbol(symbol);
}

// == expander ==

// Defines `Val MAP(Val xs)`, which returns a fresh list of F(x) for each element
// x of the NIL-terminated list xs. Recursion depth equals the list length.
// Instantiations take no trailing ';'. The macro expands to a complete function
// definition, and a bare ';' at file scope is not valid ISO C.
#define DEFINE_MAP(MAP,F)                       \
    Val MAP(Val xs_) {                          \
        return (xs_->tag == NIL) ? xs_          \
            : cons(F(car(xs_)), MAP(cdr(xs_))); \
    }

// cars: first element of each list (e.g. let binding names).
DEFINE_MAP(cars, car)
// cadrs: second element of each list (e.g. let binding values).
DEFINE_MAP(cadrs, cadr)

Val expandExpr(Val sexp);

// expandList: expands every element of an application form.
DEFINE_MAP(expandList, expandExpr)

// Expands one body form. A `define` becomes a SETBANG node and its name is
// prepended to *names so expandBlock() can bind it. Any other form is passed to
// expandExpr().
Val expandDefinition(Val sexp, Val *names) {
    if (sexp->tag != PAIR || !symbolEq(car(sexp), "define"))
        return expandExpr(sexp);

    if (cdr(sexp)->tag != PAIR) {
        fprintf(stderr, "error: malformed define\n");
        exit(1);
    }
    Val name = cadr(sexp);
    Val body;
    if (name->tag == PAIR) {
        // (define (f x ...) e ...)  ->  (define f (lambda (x ...) e ...))
        static struct val lambda = {.tag = SYMBOL, .name = "lambda"};
        body = cons(&lambda, cons(cdr(name), cddr(sexp)));
        name = car(name);
    } else {
        // Only take caddr when it exists. (define x) would otherwise read past
        // the fieldless `nil`. It is treated as (define x '()).
        body = cddr(sexp)->tag == PAIR ? caddr(sexp) : &nil;
    }
    Val ast = allocVal(SETBANG, 2);
    ast->field[0] = name;
    ast->field[1] = expandExpr(body);
    *names = cons(name, *names);
    return ast;
}

// Expands a body (a list of forms) into one AST. Forms are chained with BEGIN
// nodes. If any forms are `define`s, the sequence is wrapped as
//   ((lambda (x ...) (set! x e) ... body) '() ...)
// so each defined name gets a fresh binding initialised to '().
// An empty body expands to the constant '().
Val expandBlock(Val list) {
    if (list->tag != PAIR)
        return expandExpr(&nil);

    // Process last-to-first so each form can be prepended as (BEGIN form rest).
    Val defined = &nil;
    Val body = NULL;
    Val pending = reverse(list);
    while (pending->tag == PAIR) {
        Val form = expandDefinition(car(pending), &defined);
        if (body == NULL) {
            body = form;
        } else {
            Val seq = allocVal(BEGIN, 2);
            seq->field[0] = form;
            seq->field[1] = body;
            body = seq;
        }
        pending = cdr(pending);
    }

    if (defined->tag != NIL) {
        // One shared '() initializer per defined name.
        Val actuals = &nil;
        Val init = allocVal(CONSTANT, 1);
        init->field[0] = &nil;
        for (Val n = defined; n->tag == PAIR; n = cdr(n))
            actuals = cons(init, actuals);

        Val frame = allocVal(LAMBDA, 2);
        frame->field[0] = defined;
        frame->field[1] = body;
        return cons(frame, actuals);
    }
    return body;
}

// Translates a reader datum into the AST consumed by compileExpr(). Derived
// forms (let, let*, begin, define-in-body) are rewritten into the core forms
// CONSTANT, LAMBDA, IFELSE, BEGIN and SETBANG. Symbols stay as variable
// references, and any other pair becomes an application list.
Val expandExpr(Val sexp) {
    if (sexp->tag == SYMBOL) {
        return sexp;
    } else if (sexp->tag != PAIR) {
        // Self-evaluating datum (e.g. a number or '()).
        Val ast = allocVal(CONSTANT, 1);
        ast->field[0] = sexp;
        return ast;
    } else if (symbolEq(car(sexp), "let")) {
        // (let ((x v) ...) body ...)  ->  ((lambda (x ...) body ...) v ...)
        static struct val lambda = {.tag = SYMBOL, .name = "lambda"};
        Val xs = cars(cadr(sexp));
        Val vs = cadrs(cadr(sexp));
        Val body = cddr(sexp);
        return expandExpr(cons(cons(&lambda, cons(xs, body)), vs));
    } else if (symbolEq(car(sexp), "let*")) {
        // (let* (b1 b2 ...) body ...)  ->  (let (b1) (let (b2) ... (begin body ...)))
        static struct val let = {.tag = SYMBOL, .name = "let"};
        static struct val begin = {.tag = SYMBOL, .name = "begin"};
        Val body = cons(&begin, cddr(sexp));
        for (Val vars = reverse(cadr(sexp)); vars->tag == PAIR; vars = cdr(vars))
            body = list3(&let, list1(car(vars)), body);
        return expandExpr(body);
    } else if (symbolEq(car(sexp), "quote")) {
        Val ast = allocVal(CONSTANT, 1);
        ast->field[0] = cadr(sexp);
        return ast;
    } else if (symbolEq(car(sexp), "lambda")) {
        Val ast = allocVal(LAMBDA, 2);
        ast->field[0] = cadr(sexp);
        ast->field[1] = expandBlock(cddr(sexp));
        return ast;
    } else if (symbolEq(car(sexp), "if")) {
        if (cdr(sexp)->tag != PAIR || cddr(sexp)->tag != PAIR) {
            fprintf(stderr, "error: malformed if\n");
            exit(1);
        }
        // The alternative is optional. (if c t) behaves like (if c t '()),
        // rather than reading a fourth element past the end of the list.
        Val rest = cdr(cddr(sexp));
        Val alt = rest->tag == PAIR ? car(rest) : &nil;
        Val ast = allocVal(IFELSE, 3);
        ast->field[0] = expandExpr(cadr(sexp));
        ast->field[1] = expandExpr(caddr(sexp));
        ast->field[2] = expandExpr(alt);
        return ast;
    } else if (symbolEq(car(sexp), "begin")) {
        return expandBlock(cdr(sexp));
    } else if (symbolEq(car(sexp), "set!")) {
        Val ast = allocVal(SETBANG, 2);
        ast->field[0] = cadr(sexp);
        ast->field[1] = expandExpr(caddr(sexp));
        return ast;
    }
    return expandList(sexp);
}

// == binary format ==

// Current growable output buffer. Bytes [outputBuf, outputPtr) are written,
// and outputSize is the allocated capacity.
char *outputBuf, *outputPtr;
size_t outputSize;

// Starts a fresh, empty 16-byte output buffer. The previous buffer is not
// freed, because callers save and restore or keep it. Exits on allocation failure.
void initOutput() {
    outputBuf = outputPtr = malloc(outputSize = 16);
    if (outputBuf == NULL) {
        fprintf(stderr, "error: out of memory\n");
        exit(1);
    }
}

// Appends one byte to the output buffer, doubling its capacity when full.
// Exits on allocation failure instead of dereferencing a NULL realloc result.
void writeByte(char c) {
    size_t used = (size_t) (outputPtr - outputBuf);
    if (used >= outputSize) {
        // Keep the old pointer until realloc succeeds.
        char *grown = realloc(outputBuf, outputSize * 2);
        if (grown == NULL) {
            fprintf(stderr, "error: out of memory\n");
            exit(1);
        }
        outputBuf = grown;
        outputPtr = grown + used;
        outputSize *= 2;
    }
    *outputPtr++ = c;
}

// Appends `len` bytes from `bs` to the output, one writeByte() at a time so
// buffer growth is handled in a single place.
void writeNBytes(const char *bs, size_t len) {
    while (len-- > 0)
        writeByte(*bs++);
}

// Appends a string literal (or a concatenation of literals) without its
// implicit terminating NUL. Embedded "\0" bytes are kept. Arrays only.
#define writeBytes(bs_) writeNBytes(bs_, sizeof(bs_) - 1)

// Emits `n` as unsigned LEB128: 7 bits per byte, least significant group first,
// with the high bit set on every byte except the last.
void writeUint(unsigned long int n) {
    for (;;) {
        unsigned long int low = n & 0x7F;
        n >>= 7;
        if (n == 0) {
            writeByte(low);
            return;
        }
        writeByte(low | 0x80);
    }
}

// Emits `n` as signed LEB128 in its shortest form. Stops once the remaining
// value fits in one 7-bit two's-complement group, i.e. lies in [-64, 63].
// The original bound (n <= -0x40) produced a non-canonical two-byte form for -64.
void writeSint(long int n) {
    while (n > 0x3F || n < -0x40) {
        writeByte((n & 0x7F) | 0x80);
        n >>= 7;  // relies on arithmetic shift of negatives (implementation-defined in C)
    }
    writeByte(n & 0x7F);
}

size_t beginLength() {
    size_t pos = outputPtr - outputBuf;
    writeBytes("\x80\x80\x80\x80\x00");
    return pos;
}

// Patches the slot reserved by beginLength() at `pos`. The number of bytes
// written after the 5-byte slot is ORed into it, 7 bits per byte.
void endLength(size_t pos) {
    size_t remaining = (size_t) (outputPtr - outputBuf) - pos - 5;
    for (char *slot = outputBuf + pos; remaining != 0; remaining >>= 7)
        *slot++ |= remaining & 0x7F;
}

// == compiler ==

// Wasm module section ids, written with writeByte() before each section body.
enum {
    TYPE_SECTION = 1,
    FUNC_SECTION = 3,
    TABLE_SECTION = 4,
    EXPORT_SECTION = 7,
    ELEM_SECTION = 9,
    CODE_SECTION = 10,
};

// Raw WebAssembly instruction encodings. They are spliced together by string
// literal concatenation and emitted with writeBytes().
// NOTE(review): the per-macro instruction names below are read off the bytes
// and macro names, assuming the final Wasm GC + tail-call opcode assignments.
// Confirm against the spec before editing.
// Type indices used here (see the type section emitted in main()):
//   0 = $PAIR, 1 = $ENV, 2 = $PROC, 3 = procedure signature.
// Locals of a compiled procedure: 0 = argument list, 1 = environment,
// 2 = scratch slot used by REF_PROC_PREPARE_CALL.
#define DROP "\x1A"  // drop
#define I32_CONST "\x41"  // i32.const; LEB128 immediate follows
#define I31_REF "\xFB\x1C"  // box an i32 as i31ref
#define I31_ONE I32_CONST "\x01" I31_REF  // boxed 1, used as the "true" result
#define CAST_I32 "\xFB\x16\x6C" "\xFB\x1D"  // cast to i31, then unbox as signed i32
#define GET_ARGS "\x20\x00"  // local.get 0 (argument list)
#define SET_ARGS "\x21\x00"  // local.set 0
#define GET_ENV "\x20\x01"  // local.get 1 (environment)
#define SET_ENV "\x21\x01"  // local.set 1
#define REF_NULL_ENV "\xD0\x01"  // null of type $ENV
#define REF_NEW_ENV "\xFB\x00\x01"  // struct.new $ENV (parent, variable list)
#define REF_ENV_DROP "\xFB\x02\x01\x00"  // $ENV field 0: parent frame
#define REF_ENV_POP "\xFB\x02\x01\x01"  // $ENV field 1: this frame's variable cells
#define REF_NULL "\xD0\x6E"  // null anyref: the empty list / false
#define REF_IS_NULL "\xD1"  // ref.is_null
#define REF_CONS "\xFB\x00\x00"  // struct.new $PAIR (car, cdr)
#define CAST_CONS "\xFB\x16\x00"  // cast anyref to $PAIR
#define REF_CAR CAST_CONS "\xFB\x02\x00\x00"  // car of an anyref pair
#define REF_CDR CAST_CONS "\xFB\x02\x00\x01"  // cdr of an anyref pair
#define SET_CAR "\xFB\x05\x00\x00"  // set $PAIR field 0 (used by set!)
#define REF_NEW_PROC "\xFB\x00\x02"  // struct.new $PROC (closure env, table index)
// Casts the callee to $PROC and stashes it in local 2. It then leaves the
// closure environment and the table index on the stack for CALL_INDIRECT /
// RETURN_CALL_INDIRECT, with the argument list already pushed below them.
#define REF_PROC_PREPARE_CALL \
    "\xFB\x16\x02" "\x22\x02" "\xFB\x02\x02\x00" "\x20\x02" "\xFB\x02\x02\x01"
#define CALL_INDIRECT "\x11\x03\x00"  // indirect call, type 3, table 0
#define RETURN_CALL_INDIRECT "\x13\x03\x00"  // tail-call form of CALL_INDIRECT
#define CALL "\x10"  // direct call; function index follows
#define IF "\x04\x6E"  // if-block whose result is anyref
#define ELSE "\x05"
#define END "\x0B"

// One emitted wasm function; funcs[i] is function index i.
struct func {
    size_t type;        // type-section index of the signature
    const char *expr;   // encoded locals + instructions, ending with END
    size_t exprLen;     // byte length of expr
};

struct func funcs[1024];
size_t funcsLen;        // zero-initialized (static storage)

// Table contents: elems[i] is the function index placed in table slot i.
size_t elems[1024];
size_t elemsLen;        // zero-initialized (static storage)

// Compile-time scope stack: one parameter list per enclosing frame, innermost
// last. envPtr points just past the innermost frame.
Val env[1024];
Val *envPtr = &env[0];

// 0 while compiling an expression in tail position. Raised around
// subexpressions whose value is consumed.
int tail;               // zero-initialized (static storage)

void compileExpr(Val expr);

// Compiles `expr` as the body of a new procedure-typed function (type 3) in a
// separate output buffer, registers it in `funcs`, and returns its function
// index. The caller's output buffer and tail state are restored afterwards.
size_t compileProc(Val expr) {
    char *prevOutputPtr = outputPtr;
    char *prevOutputBuf = outputBuf;
    size_t prevOutputSize = outputSize;
    initOutput();
    // (local 2 (ref $PROC)): heap type 2 is $PROC, and REF_PROC_PREPARE_CALL
    // casts the callee to $PROC before tee-ing it into this local.
    writeBytes("\x01\x01\x64\x02");
    int prevTail = tail;
    tail = 0;  // the body itself is in tail position
    // Prologue: env = new $ENV(parent = closure env, vars = argument list).
    writeBytes(GET_ENV GET_ARGS REF_NEW_ENV SET_ENV);
    compileExpr(expr);
    writeBytes(END);
    tail = prevTail;
    // Nested lambdas are registered first, so the bound check belongs here.
    if (funcsLen >= sizeof funcs / sizeof funcs[0]) {
        fprintf(stderr, "error: too many functions\n");
        exit(1);
    }
    size_t funcidx = funcsLen++;
    funcs[funcidx].type = 3;
    funcs[funcidx].expr = outputBuf;
    funcs[funcidx].exprLen = outputPtr - outputBuf;
    outputPtr = prevOutputPtr;
    outputBuf = prevOutputBuf;
    outputSize = prevOutputSize;
    return funcidx;
}

void compileVar(const char *name) {
    Val *envPos = envPtr;
    writeBytes(GET_ENV);
    while (envPos > env) {
        Val params = *--envPos;
        for (size_t offset = 0; params->tag != NIL; offset++) {
            if (symbolEq(car(params), name)) {
                writeBytes(REF_ENV_POP);
                for (size_t i = 0; i < offset; i++)
                    writeBytes(REF_CDR);
                return;
            }
            params = cdr(params);
        }
        writeBytes(REF_ENV_DROP);
    }
    fprintf(stderr, "variable not found: %s\n", name);
    exit(1);
}

// Emits code that builds a quoted datum. Numbers become boxed i31 values and
// pairs are rebuilt recursively. Everything else, including '() and (for now)
// symbols, becomes a null reference.
void compileConstant(Val data) {
    switch (data->tag) {
    case NUMBER:
        writeBytes(I32_CONST);
        writeSint(data->value);
        writeBytes(I31_REF);
        break;
    case PAIR:
        compileConstant(car(data));
        compileConstant(cdr(data));
        writeBytes(REF_CONS);
        break;
    default:
        // TODO(?): symbol
        writeBytes(REF_NULL);
        break;
    }
}

// Emits code that evaluates each element of `list` left to right and conses
// the results into a runtime list. All values are pushed first, then the null
// terminator, then one REF_CONS per element.
void compileList(Val list) {
    size_t pending = 0;
    for (; list->tag == PAIR; list = cdr(list), pending++)
        compileExpr(car(list));
    writeBytes(REF_NULL);
    while (pending-- > 0)
        writeBytes(REF_CONS);
}

// Emits code that leaves the value of the expanded expression `expr` on the
// stack. `tail` is 0 exactly when `expr` is in tail position. It is raised around
// subexpressions whose value is consumed, and only a tail-position application
// uses RETURN_CALL_INDIRECT.
void compileExpr(Val expr) {
    if (expr->tag == CONSTANT) {
        compileConstant(expr->field[0]);
    } else if (expr->tag == SYMBOL) {
        compileVar(expr->name);
        writeBytes(REF_CAR);
    } else if (expr->tag == SETBANG) {
        if (expr->field[0]->tag != SYMBOL) {
            fprintf(stderr, "error: set! target is not a symbol\n");
            exit(1);
        }
        compileVar(expr->field[0]->name);
        writeBytes(CAST_CONS);
        tail++;
        compileExpr(expr->field[1]);
        tail--;
        writeBytes(SET_CAR REF_NULL);  // set! evaluates to '()
    } else if (expr->tag == LAMBDA) {
        if ((size_t) (envPtr - env) >= sizeof env / sizeof env[0]) {
            fprintf(stderr, "error: lambdas nested too deeply\n");
            exit(1);
        }
        *envPtr++ = expr->field[0];
        size_t funcidx = compileProc(expr->field[1]);
        envPtr--;
        if (elemsLen >= sizeof elems / sizeof elems[0]) {
            fprintf(stderr, "error: too many procedures\n");
            exit(1);
        }
        size_t elemidx = elemsLen++;
        elems[elemidx] = funcidx;
        // Closure = $PROC(current env, table slot of the compiled body).
        writeBytes(GET_ENV I32_CONST);
        writeSint(elemidx);
        writeBytes(REF_NEW_PROC);
    } else if (expr->tag == IFELSE) {
        tail++;
        compileExpr(expr->field[0]);
        tail--;
        // A null test ('()) selects the alternative; anything else the consequent.
        writeBytes(REF_IS_NULL IF);
        compileExpr(expr->field[2]);
        writeBytes(ELSE);
        compileExpr(expr->field[1]);
        writeBytes(END);
    } else if (expr->tag == BEGIN) {
        tail++;
        compileExpr(expr->field[0]);
        tail--;
        writeBytes(DROP);
        compileExpr(expr->field[1]);
    } else {
        // Application (f a ...): argument list first, then the callee.
        tail++;
        compileList(cdr(expr));
        compileExpr(car(expr));
        tail--;
        writeBytes(REF_PROC_PREPARE_CALL);
        if (tail == 0)
            writeBytes(RETURN_CALL_INDIRECT);
        else
            writeBytes(CALL_INDIRECT);
    }
}

// == builtin environment ==

// Registers a hand-encoded builtin procedure. `expr` (exprLen bytes) becomes a
// type-3 function with its own table slot. The emitted code prepends a $PROC
// for that slot (with a null environment) onto the list in local 0, and `name`
// is prepended to the compile-time builtin frame so the two lists stay aligned.
void builtinN(const char *name, const char *expr, size_t exprLen) {
    size_t fn = funcsLen++;
    funcs[fn] = (struct func){ .type = 3, .expr = expr, .exprLen = exprLen };

    size_t slot = elemsLen++;
    elems[slot] = fn;

    writeBytes(REF_NULL_ENV I32_CONST);
    writeSint(slot);
    writeBytes(REF_NEW_PROC GET_ARGS REF_CONS SET_ARGS);

    *envPtr = cons(makeSymbol(name), *envPtr);
}

// builtinN for string-literal bodies; the length excludes the implicit NUL.
#define builtin(name_,expr_) builtinN(name_, expr_, sizeof(expr_) - 1)

void initBuiltins() {
#define CODE_INT_OP2(op_) "\0"                       \
        GET_ARGS REF_CAR CAST_I32               \
        GET_ARGS REF_CDR REF_CAR CAST_I32       \
        op_ I31_REF END

#define CODE_INT_CMP(op_) "\0"                       \
        GET_ARGS REF_CAR CAST_I32               \
        GET_ARGS REF_CDR REF_CAR CAST_I32       \
        op_ IF I31_ONE ELSE REF_NULL END END

    builtin("<", CODE_INT_CMP("\x48"));
    builtin(">", CODE_INT_CMP("\x4A"));
    builtin("=", CODE_INT_CMP("\x46"));
    builtin("*", CODE_INT_OP2("\x6C"));
    builtin("-", CODE_INT_OP2("\x6B"));
    builtin("+", CODE_INT_OP2("\x6A"));
    builtin("car", "\0" GET_ARGS REF_CAR REF_CAR END);
    builtin("cdr", "\0" GET_ARGS REF_CAR REF_CDR END);
    builtin("cons", "\0" GET_ARGS REF_CAR GET_ARGS REF_CDR REF_CAR REF_CONS END);
    builtin("null?", "\0" GET_ARGS REF_CAR REF_IS_NULL IF I31_ONE ELSE REF_NULL END END);
}

// Reads a program from stdin, compiles it, and writes a WebAssembly module
// exporting "start" (returns the program's result as an i32) to stdout.
// Returns 1 on I/O errors or oversized input.
int main(void) {
    // Static storage: a 1 MiB automatic array can overflow small default stacks.
    static char sourceBuf[1048576];
    size_t sourceCap = sizeof sourceBuf - 1;  // keep room for the NUL terminator
    size_t sourceLen = fread(sourceBuf, 1, sourceCap, stdin);
    if (ferror(stdin)) {
        fprintf(stderr, "error: failed to read input\n");
        return 1;
    }
    // A full buffer with more input pending would otherwise be silently truncated.
    if (sourceLen == sourceCap && getchar() != EOF) {
        fprintf(stderr, "error: input too large\n");
        return 1;
    }
    sourceBuf[sourceLen] = '\0';
    source = sourceBuf;
    Val program = readSexpList('\0');  // renamed from `main`, which shadowed this function
    program = expandBlock(program);

    // Reserve the start function's index now; its body is recorded below, after
    // the builtins and the program have been compiled.
    size_t startFuncidx = funcsLen++;
    initOutput();
    // (local 0 anyref)
    writeBytes("\x01\x01\x6E");
    *envPtr = &nil;
    initBuiltins();
    envPtr++;
    writeBytes(GET_ARGS REF_NULL_ENV CALL);
    writeUint(compileProc(program));
    writeBytes(CAST_I32 END);
    funcs[startFuncidx].type = 4;
    funcs[startFuncidx].expr = outputBuf;
    funcs[startFuncidx].exprLen = outputPtr - outputBuf;

    size_t sectionLen;
    initOutput();
    writeBytes("\0asm\1\0\0\0");

    writeByte(TYPE_SECTION);
    sectionLen = beginLength();
    writeUint(5);
    // (type $PAIR 0 (struct (field mut anyref) (field mut anyref)))
    writeBytes("\x5F\x02\x6E\x01\x6E\x01");
    // (type $ENV 1 (struct (field (ref null $ENV)) (field anyref)))
    writeBytes("\x5F\x02\x63\x01\x00\x6E\x00");
    // (type $PROC 2 (struct (field (ref null $ENV)) (field i32)))
    writeBytes("\x5F\x02\x63\x01\x00\x7F\x00");
    // (type 3 (func (param anyref) (param (ref null $ENV)) (result anyref)))
    writeBytes("\x60\x02\x6E\x63\x01\x01\x6E");
    // (type 4 (func (result i32)))
    writeBytes("\x60\x00\x01\x7F");
    endLength(sectionLen);

    writeByte(FUNC_SECTION);
    sectionLen = beginLength();
    writeUint(funcsLen);
    for (size_t i = 0; i < funcsLen; i++)
        writeUint(funcs[i].type);
    endLength(sectionLen);

    writeByte(TABLE_SECTION);
    sectionLen = beginLength();
    // (table 0 <n> funcref)
    writeBytes("\x01\x70\x00");
    writeUint(elemsLen);
    endLength(sectionLen);

    writeByte(EXPORT_SECTION);
    sectionLen = beginLength();
    // (export "start" (func <i>))
    writeBytes("\x01\x05""start\x00");
    writeUint(startFuncidx);
    endLength(sectionLen);

    writeByte(ELEM_SECTION);
    sectionLen = beginLength();
    // (elem 0 (i32.const 0) func <i>*)
    writeBytes("\x01\x00\x41\x00\x0B");
    writeUint(elemsLen);
    for (size_t i = 0; i < elemsLen; i++)
        writeUint(elems[i]);
    endLength(sectionLen);

    writeByte(CODE_SECTION);
    sectionLen = beginLength();
    writeUint(funcsLen);
    for (size_t i = 0; i < funcsLen; i++) {
        writeUint(funcs[i].exprLen);
        writeNBytes(funcs[i].expr, funcs[i].exprLen);
    }
    endLength(sectionLen);

    char *writePtr = outputBuf;
    while (writePtr < outputPtr) {
        size_t written = fwrite(writePtr, 1, (size_t) (outputPtr - writePtr), stdout);
        // fwrite returns 0 on error; without this check the loop never ends.
        if (written == 0) {
            fprintf(stderr, "error: failed to write output\n");
            return 1;
        }
        writePtr += written;
    }
    if (fflush(stdout) != 0) {
        fprintf(stderr, "error: failed to write output\n");
        return 1;
    }
    return 0;
}
